{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Copyright (c) Microsoft Corporation. All rights reserved.*  \n",
    "\n",
    "*Licensed under the MIT License.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Natural Language Inference on MultiNLI Dataset using Transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Before You Start\n",
    "\n",
    "It takes about 4 hours to fine-tune the `bert-large-cased` model on a Standard_NC24rs_v3 Azure Data Science Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n",
    "> **Tip:** If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
    "\n",
    "\n",
    "If you run into CUDA out-of-memory error, try reducing the `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
    "QUICK_RUN = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "In this notebook, we demostrate fine-tuning pretrained transformer models to perform Natural Language Inference (NLI). We use the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral.   \n",
    "To classify a sentence pair, we concatenate the tokens in both sentences and separate the sentences by the special [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task.The NLI task essentially becomes a sequence classification task. For example, the figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. \n",
    "<img src=\"https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG\">\n",
    "\n",
    "We compare the training time and performance of bert-large-cased and xlnet-large-cased. The model used can be set in the **Configurations** section. "
   ]
  },
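  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input layout described in the summary can be sketched in plain Python. This is only an illustration under stated assumptions: `build_pair_input` is a hypothetical helper, whitespace splitting stands in for the model's subword tokenizer, the example sentences are made up, and the segment ids follow the convention of the BERT paper rather than the exact behavior of this repository's `Processor`.\n",
    "\n",
    "```python\n",
    "def build_pair_input(tokens_a, tokens_b):\n",
    "    # Hypothetical helper: lay out two token lists as [CLS] A [SEP] B [SEP].\n",
    "    tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']\n",
    "    # Segment ids: 0 covers [CLS], sentence A and its [SEP]; 1 covers sentence B and its [SEP].\n",
    "    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)\n",
    "    return tokens, segment_ids\n",
    "\n",
    "\n",
    "# Illustrative sentence pair; whitespace splitting replaces real subword tokenization.\n",
    "premise = 'A man is playing a guitar'.split()\n",
    "hypothesis = 'A person makes music'.split()\n",
    "\n",
    "tokens, segment_ids = build_pair_input(premise, hypothesis)\n",
    "print(tokens)\n",
    "print(segment_ids)\n",
    "```\n",
    "\n",
    "Because both sentences travel through the model as one sequence, the representation at the `[CLS]` position can summarize the pair as a whole."
   ]
  },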
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import sys, os\n",
    "nlp_path = os.path.abspath('../../')\n",
    "if nlp_path not in sys.path:\n",
    "    sys.path.insert(0, nlp_path)\n",
    "\n",
    "import scrapbook as sb\n",
    "\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import torch\n",
    "\n",
    "from utils_nlp.models.transformers.sequence_classification import Processor, SequenceClassifier\n",
    "from utils_nlp.dataset.multinli import load_pandas_df\n",
    "from utils_nlp.common.pytorch_utils import dataloader_from_dataset\n",
    "from utils_nlp.common.timer import Timer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see all the model supported by `SequenceClassifier`, call the `list_supported_models` method.  \n",
    "**Note**: Although `SequenceClassifier` supports distilbert for single sequence classification, distilbert doesn't support sentence pair classification and can not be used in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequenceClassifier.list_supported_models()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = \"bert-large-cased\"\n",
    "TO_LOWER = False\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "# MODEL_NAME = \"xlnet-large-cased\"\n",
    "# TO_LOWER = False\n",
    "# BATCH_SIZE = 16\n",
    "\n",
    "TRAIN_DATA_USED_FRACTION = 1\n",
    "DEV_DATA_USED_FRACTION = 1\n",
    "NUM_EPOCHS = 2\n",
    "WARMUP_STEPS= 2500\n",
    "\n",
    "if QUICK_RUN:\n",
    "    TRAIN_DATA_USED_FRACTION = 0.001\n",
    "    DEV_DATA_USED_FRACTION = 0.01\n",
    "    NUM_EPOCHS = 1\n",
    "    WARMUP_STEPS= 10\n",
    "\n",
    "if not torch.cuda.is_available():\n",
    "    BATCH_SIZE = BATCH_SIZE/2\n",
    "\n",
    "RANDOM_SEED = 42\n",
    "\n",
    "# model configurations\n",
    "MAX_SEQ_LENGTH = 128\n",
    "\n",
    "# optimizer configurations\n",
    "LEARNING_RATE= 5e-5\n",
    "\n",
    "# data configurations\n",
    "TEXT_COL_1 = \"sentence1\"\n",
    "TEXT_COL_2 = \"sentence2\"\n",
    "LABEL_COL = \"gold_label\"\n",
    "LABEL_COL_NUM = \"gold_label_num\"\n",
    "\n",
    "CACHE_DIR = TemporaryDirectory().name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "The MultiNLI dataset comes with three subsets: train, dev_matched, dev_mismatched. The dev_matched dataset are from the same genres as the train dataset, while the dev_mismatched dataset are from genres not seen in the training dataset.   \n",
    "The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\")\n",
    "dev_df_matched = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev_matched\")\n",
    "dev_df_mismatched = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev_mismatched\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_df_matched = dev_df_matched.loc[dev_df_matched['gold_label'] != '-']\n",
    "dev_df_mismatched = dev_df_mismatched.loc[dev_df_mismatched['gold_label'] != '-']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Training dataset size: {}\".format(train_df.shape[0]))\n",
    "print(\"Development (matched) dataset size: {}\".format(dev_df_matched.shape[0]))\n",
    "print(\"Development (mismatched) dataset size: {}\".format(dev_df_mismatched.shape[0]))\n",
    "print()\n",
    "print(train_df[['gold_label', 'sentence1', 'sentence2']].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample\n",
    "train_df = train_df.sample(frac=TRAIN_DATA_USED_FRACTION).reset_index(drop=True)\n",
    "dev_df_matched = dev_df_matched.sample(frac=DEV_DATA_USED_FRACTION).reset_index(drop=True)\n",
    "dev_df_mismatched = dev_df_mismatched.sample(frac=DEV_DATA_USED_FRACTION).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_encoder = LabelEncoder()\n",
    "train_labels = label_encoder.fit_transform(train_df[LABEL_COL])\n",
    "train_df[LABEL_COL_NUM] = train_labels \n",
    "num_labels = len(np.unique(train_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenize and Preprocess\n",
    "Before training, we tokenize and preprocess the sentence texts to convert them into the format required by transformer model classes.  \n",
    "The `dataset_from_dataframe` method of the `Processor` class performs the following preprocessing steps and returns a Pytorch `DataSet`\n",
    "* Tokenize input texts using the tokenizer of the pre-trained model specified by `model_name`. \n",
    "* Convert the tokens into token indices corresponding to the tokenizer's vocabulary.\n",
    "* Pad or truncate the token lists to the specified max length."
   ]
  },
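  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index-conversion and pad-or-truncate steps listed above can be sketched as follows. This is a simplified illustration, not the `Processor` implementation: `encode_and_pad`, the toy vocabulary, and the `pad_id`/`unk_id` values are all assumptions made for the example, and the naive truncation shown here simply cuts the tail, whereas a real tokenizer would keep the closing special token.\n",
    "\n",
    "```python\n",
    "def encode_and_pad(tokens, vocab, max_len, pad_id=0, unk_id=1):\n",
    "    # Map tokens to vocabulary indices; unknown tokens fall back to unk_id.\n",
    "    ids = [vocab.get(tok, unk_id) for tok in tokens]\n",
    "    # Truncate sequences longer than max_len (naive tail cut).\n",
    "    ids = ids[:max_len]\n",
    "    # Attention mask: 1 for real tokens, 0 for padding positions.\n",
    "    mask = [1] * len(ids)\n",
    "    padding = max_len - len(ids)\n",
    "    return ids + [pad_id] * padding, mask + [0] * padding\n",
    "\n",
    "\n",
    "# Toy vocabulary, purely illustrative.\n",
    "toy_vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, 'hello': 4, 'world': 5}\n",
    "tokens = ['[CLS]', 'hello', 'there', 'world', '[SEP]']\n",
    "\n",
    "padded_ids, padded_mask = encode_and_pad(tokens, toy_vocab, max_len=8)\n",
    "short_ids, short_mask = encode_and_pad(tokens, toy_vocab, max_len=3)\n",
    "print(padded_ids, padded_mask)\n",
    "print(short_ids, short_mask)\n",
    "```\n",
    "\n",
    "Fixing every sequence to the same length lets examples be stacked into tensors of one shape, which is also why lowering `MAX_SEQ_LENGTH` reduces memory use."
   ]
  },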
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = Processor(model_name=MODEL_NAME, cache_dir=CACHE_DIR, to_lower=TO_LOWER)\n",
    "\n",
    "train_dataset = processor.dataset_from_dataframe(\n",
    "    df=train_df,\n",
    "    text_col=TEXT_COL_1,\n",
    "    label_col=LABEL_COL_NUM,\n",
    "    text2_col=TEXT_COL_2,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    ")\n",
    "dev_dataset_matched = processor.dataset_from_dataframe(\n",
    "    df=dev_df_matched,\n",
    "    text_col=TEXT_COL_1,    \n",
    "    text2_col=TEXT_COL_2,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    ")\n",
    "dev_dataset_mismatched = processor.dataset_from_dataframe(\n",
    "    df=dev_df_mismatched,\n",
    "    text_col=TEXT_COL_1,    \n",
    "    text2_col=TEXT_COL_2,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    ")\n",
    "\n",
    "train_dataloader = dataloader_from_dataset(\n",
    "    train_dataset, batch_size=BATCH_SIZE, shuffle=True\n",
    ")\n",
    "dev_dataloader_matched = dataloader_from_dataset(\n",
    "    dev_dataset_matched, batch_size=BATCH_SIZE, shuffle=False\n",
    ")\n",
    "dev_dataloader_mismatched = dataloader_from_dataset(\n",
    "    dev_dataset_mismatched, batch_size=BATCH_SIZE, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Predict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = SequenceClassifier(\n",
    "    model_name=MODEL_NAME, num_labels=num_labels, cache_dir=CACHE_DIR\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "with Timer() as t:\n",
    "    classifier.fit(\n",
    "            train_dataloader,\n",
    "            num_epochs=NUM_EPOCHS,\n",
    "            learning_rate=LEARNING_RATE,\n",
    "            warmup_steps=WARMUP_STEPS,\n",
    "        )\n",
    "\n",
    "print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict on Test Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Timer() as t:\n",
    "    predictions_matched = classifier.predict(dev_dataloader_matched)\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Timer() as t:\n",
    "    predictions_mismatched = classifier.predict(dev_dataloader_mismatched)\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_matched = label_encoder.inverse_transform(predictions_matched)\n",
    "print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)\n",
    "print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare Model Performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "|Model name|Training time|Scoring time|Matched F1|Mismatched F1|\n",
    "|:--------:|:-----------:|:----------:|:--------:|:-----------:|\n",
    "|xlnet-large-cased|5.15 hrs|0.11 hrs|0.887|0.890|\n",
    "|bert-large-cased|4.01 hrs|0.08 hrs|0.867|0.867|"
   ]
  },
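  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scores recorded in the next cell come from the `weighted avg` row of `classification_report`, which averages the per-class scores weighted by each class's support. The sketch below reproduces that calculation for F1 in plain Python so the number is easier to interpret. `weighted_f1` is a hypothetical helper written for this illustration and the labels are toy data; whether the F1 values in the table above were computed exactly this way is an assumption.\n",
    "\n",
    "```python\n",
    "def weighted_f1(y_true, y_pred):\n",
    "    # Support-weighted mean of per-class F1 scores.\n",
    "    total = 0.0\n",
    "    for label in sorted(set(y_true)):\n",
    "        pairs = list(zip(y_true, y_pred))\n",
    "        tp = sum(1 for t, p in pairs if t == label and p == label)\n",
    "        fp = sum(1 for t, p in pairs if t != label and p == label)\n",
    "        fn = sum(1 for t, p in pairs if t == label and p != label)\n",
    "        precision = tp / (tp + fp) if tp + fp else 0.0\n",
    "        recall = tp / (tp + fn) if tp + fn else 0.0\n",
    "        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
    "        # Support is the number of true examples of this class.\n",
    "        total += f1 * (tp + fn)\n",
    "    return total / len(y_true)\n",
    "\n",
    "\n",
    "# Toy labels, purely illustrative.\n",
    "y_true = ['entailment', 'entailment', 'neutral', 'contradiction']\n",
    "y_pred = ['entailment', 'neutral', 'neutral', 'contradiction']\n",
    "score = weighted_f1(y_true, y_pred)\n",
    "print(round(score, 3))\n",
    "```\n",
    "\n",
    "Weighting by support means frequent classes influence the score more than rare ones; with the roughly balanced MultiNLI classes, the weighted and unweighted averages should be close."
   ]
  },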
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_matched_dict = classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3, output_dict=True)\n",
    "result_mismatched_dict = classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3, output_dict=True)\n",
    "sb.glue(\"matched_precision\", result_matched_dict[\"weighted avg\"][\"precision\"])\n",
    "sb.glue(\"matched_recall\", result_matched_dict[\"weighted avg\"][\"recall\"])\n",
    "sb.glue(\"matched_f1\", result_matched_dict[\"weighted avg\"][\"f1-score\"])\n",
    "sb.glue(\"mismatched_precision\", result_mismatched_dict[\"weighted avg\"][\"precision\"])\n",
    "sb.glue(\"mismatched_recall\", result_mismatched_dict[\"weighted avg\"][\"recall\"])\n",
    "sb.glue(\"mismatched_f1\", result_mismatched_dict[\"weighted avg\"][\"f1-score\"])"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "nlp_gpu",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
